import torch
from torch import nn
from torch.autograd import Function

import sparse_matmul_cuda

class SparseMatmulFunction(Function):
    """Autograd wrapper around the ``sparse_matmul_cuda`` extension.

    Intended to compute ``torch.einsum('bmd,bnd->bmn', x, y) * mask`` (the
    reference this file's self-test compares against), evaluating only the
    ``(m, n)`` positions where ``mask`` is nonzero. Only CUDA tensors are
    supported.
    """

    @staticmethod
    def forward(ctx, x, y, mask):
        """Run the masked batched product.

        Args:
            x: CUDA tensor, B x M x D.
            y: CUDA tensor, B x N x D.
            mask: M x N tensor; nonzero entries select which outputs are computed.

        Returns:
            Zero-initialised B x M x N tensor (dtype/device of ``x``) filled in by
            the extension.

        Raises:
            NotImplementedError: if ``x`` is not a CUDA tensor.
        """
        if not x.is_cuda:
            raise NotImplementedError('CPU version of sparse mat mul is not implemented!')

        # (K, 2) coordinates of the nonzero mask entries; each row is an (m, n)
        # pair indexing into x and y respectively.
        # NOTE(review): the indices are cast to x.dtype, presumably because the
        # extension reads them with x's scalar type. float16 represents integers
        # exactly only up to 2048 (bfloat16 only up to 256), so large M/N would
        # corrupt indices for half-precision inputs -- confirm against the kernel.
        index = torch.nonzero(mask).to(x.dtype)
        output = torch.zeros(x.shape[0], x.shape[1], y.shape[1], device=x.device, dtype=x.dtype)
        # `output` was just allocated and is contiguous, so the kernel writes into
        # it directly (no temporary copy).
        sparse_matmul_cuda.forward(x.contiguous(), y.contiguous(), index.contiguous(), output)
        ctx.save_for_backward(x, y, mask)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``x`` and ``y``; ``mask`` receives none."""
        x, y, mask = ctx.saved_tensors
        # The extension fills the buffers it is handed. ``zeros_like`` would copy
        # x's/y's layout, so for non-contiguous inputs ``.contiguous()`` would pass
        # a temporary copy and the computed gradients would be discarded. Allocate
        # them contiguous explicitly and pass them as-is.
        grad_x = torch.zeros_like(x, memory_format=torch.contiguous_format)
        grad_y = torch.zeros_like(y, memory_format=torch.contiguous_format)
        sparse_matmul_cuda.backward(grad_output.contiguous(),
                                    grad_x,
                                    grad_y,
                                    x.contiguous(),
                                    y.contiguous(),
                                    mask.to(x.dtype).contiguous())
        grad_mask = None  # the mask is a selector, not a differentiable input
        return grad_x, grad_y, grad_mask


class SparseMatmul(nn.Module):
    """Module wrapper around ``SparseMatmulFunction`` that validates input shapes."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, mask):
        """Perform a masked batched matrix multiplication over the last dimension.

        Both inputs must be 3-D and share their batch size and last dimension.

        Args:
            x: first input matrix, B x M x D.
            y: second input matrix, B x N x D.
            mask: tensor of 1s and 0s with shape M x N; entry (m, n) is 1 if
                output[:, m, n] should be computed and 0 otherwise.

        Returns:
            B x M x N tensor produced by ``SparseMatmulFunction``.
        """
        # SparseMatmulFunction allocates its output from x.shape[0], x.shape[1]
        # and y.shape[1], so only exactly one batch dimension is supported; the
        # mask index it builds has one (m, n) column pair, so the mask must be 2-D.
        assert x.dim() == 3 and y.dim() == 3, '{}, {}'.format(x.shape, y.shape)
        assert mask.dim() == 2, '{}'.format(mask.shape)
        assert x.shape[0:-2] == y.shape[0:-2]
        assert x.shape[-1] == y.shape[-1]
        assert mask.shape[-2] == x.shape[-2] and mask.shape[-1] == y.shape[-2], '{}, {}, {}'.format(mask.shape, x.shape, y.shape)

        return SparseMatmulFunction.apply(x, y, mask)




if __name__ == "__main__":
    # Self-test: compare SparseMatmul against a dense einsum + mask reference on
    # the GPU (outputs and input gradients) and report mean time per iteration.
    import time

    if not torch.cuda.is_available():
        raise SystemExit('sparse_matmul self-test requires a CUDA device')

    B = 1
    M = 256
    N = 256
    D = 64
    LOOP = 1

    def timed(step):
        """Run ``step`` LOOP times; return its last result and mean seconds per call.

        CUDA work is launched asynchronously, so the device is synchronized
        before starting and before stopping the clock; without that only the
        launch overhead would be measured.
        """
        torch.cuda.synchronize()
        tic = time.perf_counter()
        for _ in range(LOOP):
            result = step()
        torch.cuda.synchronize()
        return result, (time.perf_counter() - tic) / LOOP

    # Leaf tensors on the GPU, so .grad is populated without retain_grad().
    x = torch.randn(B, M, D).cuda().requires_grad_()
    y = torch.randn(B, N, D).cuda().requires_grad_()
    mask = torch.randint(0, 2, (M, N)).cuda()
    # Independent copies for the dense reference so gradients do not mix.
    x_torch = x.detach().clone().requires_grad_()
    y_torch = y.detach().clone().requires_grad_()
    mask_torch = mask.clone()

    sparse_matmul_ = SparseMatmul()

    def sparse_step():
        result = sparse_matmul_(x, y, mask)
        result.mean().backward()
        return result

    def dense_step():
        result = torch.einsum('bmd,bnd->bmn', x_torch, y_torch)  # B x M x N
        result = result * mask_torch.unsqueeze(dim=0)
        result.mean().backward()
        return result

    out, sparse_time = timed(sparse_step)
    print('sparse time: {}'.format(sparse_time))

    out_torch, dense_time = timed(dense_step)
    print('dense time: {}'.format(dense_time))

    print('diff: ', torch.allclose(out, out_torch))
    print('diff: ', (out - out_torch).abs().sum())

    print('x grad diff: ', torch.allclose(x.grad, x_torch.grad))
    print('x grad diff: ', (x.grad - x_torch.grad).abs().sum())

    print('y grad diff: ', torch.allclose(y.grad, y_torch.grad))
    print('y grad diff: ', (y.grad - y_torch.grad).abs().sum())
    


